{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6bYaCABobL5q"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "FlUw7tSKbtg4"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61dp4Hg5ksTC"
      },
      "source": [
        "# Reuse pre-existing TF1.x checkpoints\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/migrate/reuse_checkpoints\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/migrate/reuse_checkpoints.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/migrate/reuse_checkpoints.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/migrate/reuse_checkpoints.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "avuMwzscPnHh"
      },
      "source": [
        "As seen in the above testing, `get_variable` calls inside shim-decorated `tf.Module` and `tf.keras.layers.Layer` methods preserve the variable naming semantics of `get_variable` in TF1.x graphs and sessions.\n",
        "\n",
        "Because [`tf.train.Checkpoint`](https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint) supports loading legacy name-based checkpoints generated by `tf.compat.v1.train.Saver`, and your variable names are likely to match, you can often use your pre-existing TF1.x name-based checkpoints out of the box."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TaYgaekzOAHf"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9kv9SmyVjGLA"
      },
      "source": [
        "## Loading checkpoints with no assignment mapping\n",
        "\n",
        "If your TF1.x program loaded checkpoints directly, with no assignment mapping from saved to loaded variable names (such as when warm-starting), and the variable names from your `tf.keras.layers.Layer` matched the original variable names exactly in the above tests, then you can load your existing checkpoint directly with `tf.train.Checkpoint`.\n",
        "\n",
        "The `tf.train.Checkpoint` [guide](https://www.tensorflow.org/guide/checkpoint) describes APIs which you can use to validate that any and all variables you expect to have been loaded are actually loaded. Remember that `get_variable` weights will only get created when your model's forward pass runs.\n",
        "\n",
        "After loading the checkpoint and running the model, consider saving out a new `tf.train.Checkpoint` in the TF2 object-oriented style and reusing that one from then on. Save it only once your variables have been created."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NkUQJUUyjOJz"
      },
      "source": [
        "## Loading checkpoints with a simple nested variable_scope assignment map\n",
        "\n",
        "TF1.x code that used checkpoints for warm-starting just the backbone of a model often relies on an *assignment map*, such as the `assignment_map` argument of `tf.compat.v1.train.init_from_checkpoint`.\n",
        "\n",
        "If the assignment map rule is a simple one like `{\"/\": \"backbone_name/\"}` (which tries to map any variable named \"FOO\" in your checkpoint to the corresponding variable named \"backbone_name/FOO\" in your model), then it may be possible to either add or remove a `variable_scope` in your `track_tf1_style_variables`-decorated call method to make sure the variable names exactly match the names in your checkpoint with no need for an assignment map.\n",
        "\n",
        "Note: As described in past sections you still need to nest `tf.compat.v1.layers` that rely on autogenerated names inside at least one `variable_scope` to ensure variables get re-used correctly.\n",
        "\n",
        "You can then load your existing checkpoint as is with `tf.train.Checkpoint`, and save out a new object-oriented checkpoint right away."
      ]
    },
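    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To check which names a simple rule such as `{\"/\": \"backbone_name/\"}` implies, the following minimal sketch adds or strips a scope prefix on a list of checkpoint variable names. It is plain Python for illustration only: the helpers `add_scope` and `strip_scope` and the example variable names are hypothetical, and the sketch does not call any TensorFlow API.\n",
        "\n",
        "```python\n",
        "def add_scope(names, scope):\n",
        "  # Returns {old_name: scope + old_name} for every checkpoint name.\n",
        "  return {name: scope + name for name in names}\n",
        "\n",
        "\n",
        "def strip_scope(names, scope):\n",
        "  # Returns {old_name: new_name} with a leading `scope` removed if present.\n",
        "  return {\n",
        "      name: name[len(scope):] if name.startswith(scope) else name\n",
        "      for name in names\n",
        "  }\n",
        "\n",
        "\n",
        "# Hypothetical variable names, as they might be listed from a checkpoint.\n",
        "checkpoint_names = ['dense/kernel', 'dense/bias']\n",
        "expected_in_model = add_scope(checkpoint_names, 'backbone_name/')\n",
        "for old_name, new_name in expected_in_model.items():\n",
        "  print(f'{old_name} -> {new_name}')\n",
        "```\n",
        "\n",
        "Comparing the printed names with the names your model actually creates can show whether adding or removing one `variable_scope` is enough."
      ]
    },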
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J7NPybZVi_Ew"
      },
      "source": [
        "## Loading checkpoints with a more complicated assignment mapping\n",
        "\n",
        "If adjusting the `variable_scope` nesting in your shim-decorated call method proves insufficient, or you have a more complicated name and scope assignment mapping, you may need to directly rename the variables in your saved TF1.x name-based checkpoint to the variable names expected after the mapping. There are several ways to do this, such as using a dictionary that maps every variable name or using regular expressions to rewrite scopes. The following code snippets demonstrate a few of these methods."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RLpGN0kFy4OZ"
      },
      "source": [
        "This first code snippet shows how you can programmatically grab the names from an existing name-based checkpoint.\n",
        "```\n",
        "import tensorflow.compat.v1 as tf1\n",
        "\n",
        "CKPT = \"/PATH/TO/ORIGINAL/CHECKPOINT\"\n",
        "with tf1.Graph().as_default():\n",
        "  # Read the checkpoint's contents without building a model.\n",
        "  checkpoint_reader = tf1.train.NewCheckpointReader(CKPT)\n",
        "  names_to_shapes = checkpoint_reader.get_variable_to_shape_map()\n",
        "  # Print every saved variable name in sorted order.\n",
        "  for name in sorted(names_to_shapes):\n",
        "    print(name)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5qrjLapzsJ1K"
      },
      "source": [
        "This second code snippet shows how you can rename variables in one TF1.x name-based checkpoint and save out a new name-based checkpoint (if you have a dict `new_name_from_old_name` mapping old variable names to post-assignment-mapping variable names).\n",
        "```\n",
        "import tensorflow.compat.v1 as tf1\n",
        "\n",
        "# `new_name_from_old_name` is the dict described above; define it first.\n",
        "to_migrate = [\n",
        "    (\"/PATH/TO/ORIGINAL/CHECKPOINT\",\n",
        "     \"/PATH/TO/CHECKPOINT/WITH/RENAMED/VARIABLES\"),\n",
        "]\n",
        "for input_ckpt, output_ckpt in to_migrate:\n",
        "  print(f'Migrating {input_ckpt} to {output_ckpt}')\n",
        "  with tf1.Graph().as_default():\n",
        "    checkpoint_reader = tf1.train.NewCheckpointReader(input_ckpt)\n",
        "    variable_from_new_name = {}\n",
        "\n",
        "    # Create one variable per renamed entry, initialized to the saved value.\n",
        "    for old_name, new_name in new_name_from_old_name.items():\n",
        "      value = checkpoint_reader.get_tensor(old_name)\n",
        "      variable_from_new_name[new_name] = tf1.Variable(value)\n",
        "\n",
        "    init = tf1.global_variables_initializer()\n",
        "    # Keying the Saver by new name writes each variable under that name.\n",
        "    saver = tf1.train.Saver(variable_from_new_name)\n",
        "\n",
        "    with tf1.Session() as sess:\n",
        "      sess.run(init)\n",
        "      saver.save(sess, output_ckpt, write_meta_graph=False)\n",
        "\n",
        "  for old_name, new_name in new_name_from_old_name.items():\n",
        "    print(f'{old_name} -> {new_name}')\n",
        "```"
      ]
    },
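    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One possible way to build a `new_name_from_old_name` dict is to apply regular-expression rewrites to the names listed from the checkpoint. The sketch below is only an illustration under stated assumptions: the helper `build_name_map`, the `SCOPE_REWRITES` rules, and the example variable names are all hypothetical, so substitute the names and scopes of your own checkpoint.\n",
        "\n",
        "```python\n",
        "import re\n",
        "\n",
        "# Hypothetical rewrite rules: (regular expression, replacement) pairs that\n",
        "# are applied in order to every variable name.\n",
        "SCOPE_REWRITES = [\n",
        "    ('^model/encoder/', 'backbone/'),\n",
        "    ('/weights$', '/kernel'),\n",
        "]\n",
        "\n",
        "\n",
        "def build_name_map(old_names, rewrites):\n",
        "  # Returns a dict mapping each old name to its rewritten name.\n",
        "  new_name_from_old_name = {}\n",
        "  for old_name in old_names:\n",
        "    new_name = old_name\n",
        "    for pattern, replacement in rewrites:\n",
        "      new_name = re.sub(pattern, replacement, new_name)\n",
        "    new_name_from_old_name[old_name] = new_name\n",
        "  return new_name_from_old_name\n",
        "\n",
        "\n",
        "# Hypothetical names, as they might be listed from a checkpoint.\n",
        "old_names = [\n",
        "    'model/encoder/dense/weights',\n",
        "    'model/encoder/dense/bias',\n",
        "    'global_step',\n",
        "]\n",
        "new_name_from_old_name = build_name_map(old_names, SCOPE_REWRITES)\n",
        "for old_name, new_name in new_name_from_old_name.items():\n",
        "  print(f'{old_name} -> {new_name}')\n",
        "```\n",
        "\n",
        "Names that match no rule map to themselves, so every variable stays in the dict and none is silently dropped when the checkpoint is rewritten."
      ]
    },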
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ejem9qLuoWp_"
      },
      "source": [
        "Once you have renamed the variables in your name-based checkpoint according to the assignment map, you should be able to load it and then save out a new object-oriented checkpoint."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZDYSFob4jazB"
      },
      "source": [
        "## Loading checkpoints when not all variables are created by `get_variable`\n",
        "\n",
        "If your shim-decorated layer or module contains some variables (or Keras layers/models) that use `tf.Variable` instead of `tf.compat.v1.get_variable` and get attached as properties or tracked in an object-oriented way, they may have different variable naming semantics in TF1.x graphs/sessions than during eager execution.\n",
        "\n",
        "So, you will have to compare these variable names before and after and explicitly rewrite the checkpoint with renamed variables, just as for more complex assignment mapping above.\n",
        "\n",
        "Once you do that you can load the checkpoint, and save out a new object-oriented checkpoint for future use.\n",
        "\n",
        "Warning: Variables may have duplicate names in eager execution, which may cause problems if multiple variables in the name-based checkpoint need to be mapped to the same name. You may be able to resolve this by explicitly setting layer and variable names with `tf.name_scope` and the `name` arguments of layer constructors or `tf.Variable`, so that no duplicates remain."
      ]
    },
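    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before rewriting a checkpoint, it can help to check whether a renaming dict would send several checkpoint variables to the same name. The following is a minimal plain-Python sketch of such a check: the helper `find_duplicate_targets` and the example names are hypothetical and not part of any TensorFlow API.\n",
        "\n",
        "```python\n",
        "import collections\n",
        "\n",
        "\n",
        "def find_duplicate_targets(new_name_from_old_name):\n",
        "  # Group old names by the new name they map to.\n",
        "  old_names_from_new_name = collections.defaultdict(list)\n",
        "  for old_name, new_name in new_name_from_old_name.items():\n",
        "    old_names_from_new_name[new_name].append(old_name)\n",
        "  # Keep only new names claimed by more than one old name.\n",
        "  return {\n",
        "      new_name: sorted(old_names)\n",
        "      for new_name, old_names in old_names_from_new_name.items()\n",
        "      if len(old_names) > 1\n",
        "  }\n",
        "\n",
        "\n",
        "# Hypothetical mapping in which two checkpoint variables collide.\n",
        "name_map = {\n",
        "    'dense/kernel': 'kernel',\n",
        "    'dense_1/kernel': 'kernel',\n",
        "    'dense/bias': 'bias',\n",
        "}\n",
        "duplicates = find_duplicate_targets(name_map)\n",
        "if duplicates:\n",
        "  print(f'Conflicting target names: {duplicates}')\n",
        "```\n",
        "\n",
        "An empty result means every target name is unique. Any reported name is a candidate for an explicit `name` argument or an extra `tf.name_scope`."
      ]
    },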
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JBMfArLQ0jb-"
      },
      "source": [
        "## Resources\n",
        "\n",
        "Refer to the following guides:\n",
        "* Validating numerical equivalence and correctness\n",
        "* TF Modeling Shims"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "reuse_checkpoints.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
